Simple Examples

This tutorial goes through a few common ML tasks using the cremi dataset and a 2D U-Net.

Introduction and overview

In this tutorial we will cover a few basic ML tasks using the DaCapo toolbox. We will:

  • Prepare a dataloader for the CREMI dataset

  • Train a simple 2D U-Net for both instance and semantic segmentation

  • Visualize the results

Environment setup

If you have not already done so, you will need to install DaCapo. You can do this by first creating a new environment and then installing the DaCapo Toolbox.

I highly recommend using uv for environment management, but there are many tools to choose from.

uv init
uv add git+https://github.com/pattonw/dacapo-toolbox.git

Data Preparation

DaCapo works with zarr, so we will download CREMI Samples A and C and save them as zarr files — sample A for training and sample C for testing.

import wget
from pathlib import Path
import dask

dask.config.set(scheduler="single-threaded")

# Download the CREMI data (sample A for training, sample C for testing)
# and immediately convert it to zarr for convenience.
# Check each file independently: previously only sample A was checked,
# so a missing sample C would never be downloaded if A was present.
for sample in ["sample_A_20160501.hdf", "sample_C_20160501.hdf"]:
    if not Path(sample).exists():
        wget.download(f"https://cremi.org/static/data/{sample}", sample)

Data Loading

We will use the funlib.persistence library to interface with zarr. This library adds support for units, voxel size, and axis names along with the ability to query our data based on a Roi object describing a specific rectangular piece of data. This is especially useful in a microscopy context where you regularly need to chunk your data for processing.

import numpy as np
from funlib.persistence import prepare_ds, open_ds
import h5py
from pathlib import Path
import re


def _h5_to_zarr(hdf_path, zarr_group):
    """Copy the raw and neuron_ids volumes from a CREMI hdf file into zarr.

    Each dataset is written with the CREMI voxel size of (40, 4, 4) nm
    and zyx axis names. Returns the (raw, labels) funlib arrays.
    """
    h5 = h5py.File(hdf_path, "r")
    arrays = []
    for h5_key, ds_name in [
        ("volumes/raw", "raw"),
        ("volumes/labels/neuron_ids", "labels"),
    ]:
        data = h5[h5_key][:]
        ds = prepare_ds(
            f"{zarr_group}/{ds_name}",
            data.shape,
            voxel_size=(40, 4, 4),
            dtype=data.dtype,
            axis_names=["z", "y", "x"],
            units=["nm", "nm", "nm"],
        )
        # write the full volume into the newly prepared dataset
        ds[ds.roi] = data
        arrays.append(ds)
    return tuple(arrays)


if not Path("cremi.zarr/train/raw").exists():
    # sample C is held out for testing; sample A is used for training
    test_raw, test_labels = _h5_to_zarr("sample_C_20160501.hdf", "cremi.zarr/test")
    train_raw, train_labels = _h5_to_zarr("sample_A_20160501.hdf", "cremi.zarr/train")
else:
    # already converted on a previous run; just open the existing arrays
    train_raw = open_ds("cremi.zarr/train/raw")
    train_labels = open_ds("cremi.zarr/train/labels")
    test_raw = open_ds("cremi.zarr/test/raw")
    test_labels = open_ds("cremi.zarr/test/labels")

Let's visualize our train and test data.

# Build a custom label color map for showing instances.
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.animation as animation
from IPython.display import HTML
import matplotlib as mpl

# Raise the notebook animation embed limit so our videos fit inline.
mpl.rcParams["animation.embed_limit"] = 50_000_000  # 50 MB, for example

# Entry 0 is black (background); the remaining 255 entries are
# reproducible random RGB colors (fixed seed).
np.random.seed(1)
colors = [[0, 0, 0]]
for _ in range(255):
    colors.append(list(np.random.choice(range(256), size=3)))
label_cmap = ListedColormap(colors)

Training data

# Animate the training volume: raw EM on the left, instance labels on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

ims = []
for i, (x, y) in enumerate(zip(train_raw.data, train_labels.data)):
    first = i == 0
    frame_raw = axes[0].imshow(x, animated=not first)
    frame_labels = axes[1].imshow(
        y % 256,
        cmap=label_cmap,
        vmin=0,
        vmax=255,
        animated=not first,
        interpolation="none",
    )
    if first:
        # titles only need to be set once, on the first frame
        axes[0].set_title("Raw Train Data")
        axes[1].set_title("Train Labels")
    ims.append([frame_raw, frame_labels])

# Play forward then backward for a smooth loop.
ims = ims + ims[::-1]
ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
HTML(video_html)
WARNING:matplotlib.animation:MovieWriter stderr:
Received > 3 system signals, hard exiting
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:224, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
    223 try:
--> 224     yield self
    225 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1126, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
   1125         frame_number += 1
-> 1126 writer.grab_frame(**savefig_kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:352, in MovieWriter.grab_frame(self, **savefig_kwargs)
    351 # Save the figure data to the sink, using the frame format and dpi.
--> 352 self.fig.savefig(self._proc.stdin, format=self.frame_format,
    353                  dpi=self.dpi, **savefig_kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3490, in Figure.savefig(self, fname, transparent, **kwargs)
   3489         _recursively_make_axes_transparent(stack, ax)
-> 3490 self.canvas.print_figure(fname, **kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2184, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2183     with cbook._setattr_cm(self.figure, dpi=dpi):
-> 2184         result = print_method(
   2185             filename,
   2186             facecolor=facecolor,
   2187             edgecolor=edgecolor,
   2188             orientation=orientation,
   2189             bbox_inches_restore=_bbox_inches_restore,
   2190             **kwargs)
   2191 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2040, in FigureCanvasBase._switch_canvas_and_return_print_method.<locals>.<lambda>(*args, **kwargs)
   2039     skip = optional_kws - {*inspect.signature(meth).parameters}
-> 2040     print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
   2041         *args, **{k: v for k, v in kwargs.items() if k not in skip}))
   2042 else:  # Let third-parties do as they see fit.

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:417, in FigureCanvasAgg.print_raw(self, filename_or_obj, metadata)
    416     raise ValueError("metadata not supported for raw/rgba")
--> 417 FigureCanvasAgg.draw(self)
    418 renderer = self.get_renderer()

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:382, in FigureCanvasAgg.draw(self)
    380 with (self.toolbar._wait_cursor_for_draw_cm() if self.toolbar
    381       else nullcontext()):
--> 382     self.figure.draw(self.renderer)
    383     # A GUI class may be need to update a window using this draw, so
    384     # don't forget to call the superclass.

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:94, in _finalize_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
     92 @wraps(draw)
     93 def draw_wrapper(artist, renderer, *args, **kwargs):
---> 94     result = draw(artist, renderer, *args, **kwargs)
     95     if renderer._rasterizing:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3257, in Figure.draw(self, renderer)
   3256 self.patch.draw(renderer)
-> 3257 mimage._draw_list_compositing_images(
   3258     renderer, self, artists, self.suppressComposite)
   3260 renderer.close_group('figure')

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    133     for a in artists:
--> 134         a.draw(renderer)
    135 else:
    136     # Composite any adjacent images together

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/axes/_base.py:3210, in _AxesBase.draw(self, renderer)
   3208     _draw_rasterized(self.get_figure(root=True), artists_rasterized, renderer)
-> 3210 mimage._draw_list_compositing_images(
   3211     renderer, self, artists, self.get_figure(root=True).suppressComposite)
   3213 renderer.close_group('axes')

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    133     for a in artists:
--> 134         a.draw(renderer)
    135 else:
    136     # Composite any adjacent images together

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:609, in _ImageBase.draw(self, renderer)
    608 else:
--> 609     im, l, b, trans = self.make_image(
    610         renderer, renderer.get_image_magnification())
    611     if im is not None:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:912, in AxesImage.make_image(self, renderer, magnification, unsampled)
    910 clip = ((self.get_clip_box() or self.axes.bbox) if self.get_clip_on()
    911         else self.get_figure(root=True).bbox)
--> 912 return self._make_image(self._A, bbox, transformed_bbox, clip,
    913                         magnification, unsampled=unsampled)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:512, in _ImageBase._make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
    509         output_alpha = _resample(  # resample alpha channel
    510             self, A[..., 3], out_shape, t)
    511     output = _resample(  # resample rgb channels
--> 512         self, _rgb_to_rgba(A[..., :3]), out_shape, t)
    513 elif np.ndim(alpha) > 0:  # Array alpha
    514     # user-specified array alpha overrides the existing alpha channel

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:600, in IPKernelApp.sigint_handler(self, *args)
    599 elif self.kernel.shell_is_blocking:
--> 600     raise KeyboardInterrupt

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

CalledProcessError                        Traceback (most recent call last)
Cell In[5], line 27
     25 ims = ims + ims[::-1]
     26 ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
---> 27 video_html = ani.to_html5_video()
     28 video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
     29 HTML(video_html)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1306, in Animation.to_html5_video(self, embed_limit)
   1302 Writer = writers[mpl.rcParams['animation.writer']]
   1303 writer = Writer(codec='h264',
   1304                 bitrate=mpl.rcParams['animation.bitrate'],
   1305                 fps=1000. / self._interval)
-> 1306 self.save(str(path), writer=writer)
   1307 # Now open and base64 encode.
   1308 vid64 = base64.encodebytes(path.read_bytes())

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1098, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
   1093     return a * np.array([r, g, b]) + 1 - a
   1095 # canvas._is_saving = True makes the draw_event animation-starting
   1096 # callback a no-op; canvas.manager = None prevents resizing the GUI
   1097 # widget (both are likewise done in savefig()).
-> 1098 with (writer.saving(self._fig, filename, dpi),
   1099       cbook._setattr_cm(self._fig.canvas, _is_saving=True, manager=None)):
   1100     if not writer._supports_transparency():
   1101         facecolor = savefig_kwargs.get('facecolor',
   1102                                        mpl.rcParams['savefig.facecolor'])

File ~/.local/share/uv/python/cpython-3.10.17-linux-x86_64-gnu/lib/python3.10/contextlib.py:153, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
    151     value = typ()
    152 try:
--> 153     self.gen.throw(typ, value, traceback)
    154 except StopIteration as exc:
    155     # Suppress StopIteration *unless* it's the same exception that
    156     # was passed to throw().  This prevents a StopIteration
    157     # raised inside the "with" statement from being suppressed.
    158     return exc is not value

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:226, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
    224     yield self
    225 finally:
--> 226     self.finish()

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:341, in MovieWriter.finish(self)
    337     _log.log(
    338         logging.WARNING if self._proc.returncode else logging.DEBUG,
    339         "MovieWriter stderr:\n%s", err)
    340 if self._proc.returncode:
--> 341     raise subprocess.CalledProcessError(
    342         self._proc.returncode, self._proc.args, out, err)

CalledProcessError: Command '['ffmpeg', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', '1200x600', '-pix_fmt', 'rgba', '-framerate', '5.0', '-loglevel', 'error', '-i', 'pipe:', '-vcodec', 'h264', '-pix_fmt', 'yuv420p', '-y', '/tmp/tmp72kydopw/temp.m4v']' returned non-zero exit status 123.
../_images/faf9e5bd39fe63e8f1d102ee6b7df1cd1c184b5c3774d2efc3f3cca53579ea99.png

Testing data

# Animate the test volume: raw EM on the left, instance labels on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

ims = []
for i, (x, y) in enumerate(zip(test_raw.data, test_labels.data)):
    first = i == 0
    frame_raw = axes[0].imshow(x, animated=not first)
    frame_labels = axes[1].imshow(
        y % 256,
        cmap=label_cmap,
        vmin=0,
        vmax=255,
        animated=not first,
        interpolation="none",
    )
    if first:
        # titles only need to be set once, on the first frame
        axes[0].set_title("Raw Test Data")
        axes[1].set_title("Test Labels")
    ims.append([frame_raw, frame_labels])

# Play forward then backward for a smooth loop.
ims = ims + ims[::-1]
ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
HTML(video_html)
../_images/ecf72298ccc9332cc60462f48684a6f51dd76c9ae5818f489e01635158443108.png

DaCapo

Now that we have some data, let's look at how we can use DaCapo to interface with it for some common ML use cases.

Data Split

We always want to be explicit when we define our data split for training and validation so that we are aware of exactly what data is being used for each.

from dacapo_toolbox.datasplits import SimpleDataSplitConfig

# Build the train/validate split directly from the zarr container layout.
datasplit = SimpleDataSplitConfig(name="cremi", path="cremi.zarr")

print(f"Train datasets: {datasplit.train}")
print(f"Validation datasets: {datasplit.validate}")
Train datasets: [SimpleDataset(name='train', path=PosixPath('cremi.zarr/train'), weight=1, raw_name='raw', gt_name='labels', mask_name='mask')]
Validation datasets: [SimpleDataset(name='test', path=PosixPath('cremi.zarr/test'), weight=1, raw_name='raw', gt_name='labels', mask_name='mask')]

Augmentation

We almost always want to use rotations when training on EM data. This is because the structures we care about rarely have strict orientations relative to the zyx axes. However, because we usually have some axial anisotropy in our data, we want to be careful about how we apply these rotations.

from dacapo_toolbox.trainers import GunpowderTrainerConfig
from dacapo_toolbox.trainers.gp_augments import ElasticAugmentConfig

# Elastic deformations that account for the axial anisotropy: control
# points and displacements are much coarser in z than in xy, and
# rotations are restricted to the xy plane.
elastic_augment = ElasticAugmentConfig(
    control_point_spacing=(2, 20, 20),
    control_point_displacement_sigma=(2, 20, 20),
    rotation_interval=(0, 3.14),
    subsample=4,
    uniform_3d_rotation=False,  # rotate only in 2D
    augmentation_probability=0.5,
)

# Build a trainer config that applies the elastic deformations.
trainer = GunpowderTrainerConfig(
    name="rotations",
    augments=[elastic_augment],
)
/home/runner/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Simple Training loop

The Trainer is only useful when combined with some data, but now that we have defined some data via the DataSplitConfig and the pipeline via the TrainerConfig, we can visualize a batch:

import torch

# Batch geometry: 13 z slices per sample, 3 samples per batch.
z_slices = 13
batch_size = 3

# Wrap the datasplit + trainer pipeline as a torch iterable dataset.
torch_dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=(z_slices, 128, 128),
    output_shape=(z_slices, 128, 128),
)

dataloader = torch.utils.data.DataLoader(
    torch_dataset, batch_size=batch_size, num_workers=0
)


# Pull a single batch and animate it, stepping through the z slices;
# one row per sample in the batch, raw on the left and labels on the right.
batch = next(iter(dataloader))
fig, axes = plt.subplots(batch_size, 2, figsize=(12, 18))

ims = []
for zz in range(z_slices):
    b_ims = []
    for bb in range(batch_size):
        # NOTE(review): raw is indexed with an extra channel dimension
        # (batch, channel, z, y, x) while gt is (batch, z, y, x) — confirm
        # against the iterable_dataset output spec.
        b_raw = batch["raw"][bb, 0, zz].numpy()
        b_labels = batch["gt"][bb, zz].numpy() % 256
        if zz == 0:
            # first frame: static images plus titles
            im = axes[bb, 0].imshow(b_raw)
            im2 = axes[bb, 1].imshow(
                b_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
            )
            if bb == 0:
                axes[bb, 0].set_title("Sample Raw")
                axes[bb, 1].set_title("Sample Labels")
        else:
            im = axes[bb, 0].imshow(b_raw, animated=True)
            im2 = axes[bb, 1].imshow(
                b_labels,
                cmap=label_cmap,
                vmin=0,
                vmax=255,
                animated=True,
                interpolation="none",
            )
        b_ims.extend([im, im2])
    ims.append(b_ims)

# Play forward then backward for a smooth loop.
ims = ims + ims[::-1]
ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
HTML(video_html)
../_images/dafdf8183db0697e7763ac25fd4ad959c9b759aff417a34dc11a1af67f9e4aa5.png

Tasks

When training for instance segmentation, it is not possible to directly predict label ids since the ids have to be unique across the full volume, which is not possible to do with the local context that a UNet operates on. So instead we need to transform our labels into some intermediate representation that is both easy to predict and easy to post-process. The most common method we use is a combination of affinities (with optional lsds) for prediction, plus mutex watershed for post-processing.

Next we will define the task that encapsulates this process.

from dacapo_toolbox.tasks import AffinitiesTaskConfig

# Affinities task: one output channel per neighborhood offset (z, y, x).
# Offsets mix short range (1 voxel) and longer range (9, 27 voxels in xy)
# edges for the downstream mutex watershed.
affs_config = AffinitiesTaskConfig(
    name="affs",
    neighborhood=[
        [0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 9],
        [0, 9, 0],
        [0, 0, 27],
        [0, 27, 0],
    ],
    # lsds=True,
)

# Same dataset as before, but with the task attached: batches now also
# carry the task's "target" (and weight) tensors.
torch_dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=(z_slices, 128, 128),
    output_shape=(z_slices, 128, 128),
    task=affs_config,
)

dataloader = torch.utils.data.DataLoader(
    torch_dataset, batch_size=batch_size, num_workers=0
)

# Pull one batch and animate raw / labels / affinities side by side.
batch = next(iter(dataloader))
fig, axes = plt.subplots(batch_size, 3, figsize=(18, 18))
ims = []
for zz in range(z_slices):
    b_ims = []
    for bb in range(batch_size):
        b_raw = batch["raw"][bb, 0, zz].numpy()
        b_labels = batch["gt"][bb, zz].numpy() % 256
        # show affinity channels 0 ([0,0,1]), 5 ([0,0,27]) and 6 ([0,27,0])
        # together as an RGB image
        b_target = batch["target"][bb, [0, 5, 6], zz].numpy()
        if zz == 0:
            # first frame: static images plus titles
            im = axes[bb, 0].imshow(b_raw)
            im2 = axes[bb, 1].imshow(
                b_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
            )
            im3 = axes[bb, 2].imshow(b_target.transpose(1, 2, 0), interpolation="none")
            if bb == 0:
                axes[bb, 0].set_title("Sample Raw")
                axes[bb, 1].set_title("Sample Labels")
                axes[bb, 2].set_title("Sample Affinities")
        else:
            im = axes[bb, 0].imshow(b_raw, animated=True)
            im2 = axes[bb, 1].imshow(
                b_labels,
                cmap=label_cmap,
                vmin=0,
                vmax=255,
                animated=True,
                interpolation="none",
            )
            im3 = axes[bb, 2].imshow(
                b_target.transpose(1, 2, 0), animated=True, interpolation="none"
            )
        b_ims.extend([im, im2, im3])
    ims.append(b_ims)

# Play forward then backward for a smooth loop.
ims = ims + ims[::-1]
ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
video_html = ani.to_html5_video()
video_html = re.sub(r"<video ", '<video width="600" height="600" ', video_html)
HTML(video_html)
../_images/863b28ce8ead3db11988f6b70887f874eab2f673158b7d5028112fc610deac53.png

Models

DaCapo lets you easily train any model you want, with a special wrapper for UNets specifically. Let's make one now.

from dacapo_toolbox.architectures import CNNectomeUNetConfig
from funlib.geometry import Coordinate, Roi

# Input shape in voxels (z, y, x): only 5 z slices since the data is
# highly anisotropic (40 nm in z vs 4 nm in xy).
input_shape = Coordinate((5, 156, 156))

unet_config = CNNectomeUNetConfig(
    name="2.5D_UNet",
    input_shape=input_shape,
    fmaps_in=1,
    fmaps_out=32,
    num_fmaps=32,
    fmap_inc_factor=4,
    # downsample only in xy, keeping the (already coarse) z resolution
    downsample_factors=[(1, 2, 2), (1, 2, 2), (1, 2, 2)],
    # mostly (1, 3, 3) kernels: 2D convolutions that don't erode the
    # thin z dimension
    kernel_size_down=[
        [(1, 3, 3), (1, 3, 3)],
        [(1, 3, 3), (1, 3, 3)],
        [(1, 3, 3), (1, 3, 3)],
        [(1, 3, 3), (1, 3, 3)],
    ],
    kernel_size_up=[
        [(1, 3, 3), (1, 3, 3)],
        [(1, 3, 3), (1, 3, 3)],
        [(3, 3, 3), (3, 3, 3)],
    ],
)

# Valid convolutions shrink the output relative to the input.
output_shape = unet_config.compute_output_shape(input_shape)
print(f"Given an input of shape {input_shape} we get an output of shape {output_shape}")
Given an input of shape (5, 156, 156) we get an out of shape (1, 64, 64)

Training loop

Now we can bring everything together and train our model.

# Bring data, task, and model together into a training loop.
dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=input_shape,
    output_shape=output_shape,
    task=affs_config,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=3,
    prefetch_factor=2,
    persistent_workers=True,
)


task = affs_config.task_type(affs_config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this ensures we output the appropriate number of channels, use the appropriate final activation etc.
module = task.create_model(unet_config).to(device)
loss = task.loss
optimizer = torch.optim.Adam(module.parameters(), lr=1e-4)

losses = []

print(f"Training on {device}")
# A DataLoader is already iterable — no need to wrap it in iter().
for iteration, batch in enumerate(dataloader):
    raw, target, weight = (
        batch["raw"].to(device),
        batch["target"].to(device),
        batch["weight"].to(device),
    )
    optimizer.zero_grad()
    output = module(raw)
    # task-provided weighted loss between prediction and affinity targets
    loss_value = loss.compute(output, target, weight)
    loss_value.backward()
    optimizer.step()
    print(f"Loss ({iteration}): {loss_value.item():.3f}")

    losses.append(loss_value.item())

    # stop after iterations 0..400 (the dataset is an endless stream)
    if iteration >= 400:
        break
Training on cpu
Loss (0): 0.657
Loss (1): 0.959
Loss (2): 0.629
Loss (3): 0.789
Loss (4): 0.816
Loss (5): 0.646
Loss (6): 0.604
Loss (7): 0.507
Loss (8): 0.566
Loss (9): 0.658
Loss (10): 0.659
Loss (11): 0.588
Loss (12): 0.519
Loss (13): 0.551
Loss (14): 0.479
Loss (15): 0.579
Loss (16): 0.555
Loss (17): 0.474
Loss (18): 0.502
Loss (19): 0.500
Loss (20): 0.389
Loss (21): 0.454
Loss (22): 0.469
Loss (23): 0.546
Loss (24): 0.425
Loss (25): 0.438
Loss (26): 0.400
Loss (27): 0.398
Loss (28): 0.401
Loss (29): 0.398
Loss (30): 0.431
Loss (31): 0.350
Loss (32): 0.394
Loss (33): 0.387
Loss (34): 0.382
Loss (35): 0.408
Loss (36): 0.442
Loss (37): 0.357
Loss (38): 0.409
Loss (39): 0.395
Loss (40): 0.424
Loss (41): 0.393
Loss (42): 0.386
Loss (43): 0.427
Loss (44): 0.346
Loss (45): 0.377
Loss (46): 0.358
Loss (47): 0.301
Loss (48): 0.322
Loss (49): 0.298
Loss (50): 0.319
Loss (51): 0.303
Loss (52): 0.399
Loss (53): 0.332
Loss (54): 0.265
Loss (55): 0.331
Loss (56): 0.337
Loss (57): 0.456
Loss (58): 0.336
Loss (59): 0.347
Loss (60): 0.311
Loss (61): 0.369
Loss (62): 0.303
Loss (63): 0.328
Loss (64): 0.330
Loss (65): 0.286
Loss (66): 0.320
Loss (67): 0.275
Loss (68): 0.375
Loss (69): 0.273
Loss (70): 0.248
Loss (71): 0.296
Loss (72): 0.275
Loss (73): 0.372
Loss (74): 0.308
Loss (75): 0.280
Loss (76): 0.280
Loss (77): 0.275
Loss (78): 0.293
Loss (79): 0.382
Loss (80): 0.332
Loss (81): 0.280
Loss (82): 0.271
Loss (83): 0.362
Loss (84): 0.256
Loss (85): 0.235
Loss (86): 0.256
Loss (87): 0.311
Loss (88): 0.391
Loss (89): 0.267
Loss (90): 0.319
Loss (91): 0.356
Loss (92): 0.374
Loss (93): 0.334
Loss (94): 0.279
Loss (95): 0.289
Loss (96): 0.230
Loss (97): 0.253
Loss (98): 0.465
Loss (99): 0.212
Loss (100): 0.247
Loss (101): 0.317
Loss (102): 0.279
Loss (103): 0.282
Loss (104): 0.368
Loss (105): 0.324
Loss (106): 0.315
Loss (107): 0.363
Loss (108): 0.262
Loss (109): 0.253
Loss (110): 0.340
Loss (111): 0.328
Loss (112): 0.258
Loss (113): 0.215
Loss (114): 0.304
Loss (115): 0.331
Loss (116): 0.298
Loss (117): 0.235
Loss (118): 0.226
Loss (119): 0.270
Loss (120): 0.247
Loss (121): 0.248
Loss (122): 0.241
Loss (123): 0.262
Loss (124): 0.265
Loss (125): 0.269
Loss (126): 0.367
Loss (127): 0.244
Loss (128): 0.285
Loss (129): 0.307
Loss (130): 0.242
Loss (131): 0.297
Loss (132): 0.307
Loss (133): 0.377
Loss (134): 0.268
Loss (135): 0.351
Loss (136): 0.307
Loss (137): 0.251
Loss (138): 0.253
Loss (139): 0.265
Loss (140): 0.223
Loss (141): 0.205
Loss (142): 0.284
Loss (143): 0.294
Loss (144): 0.243
Loss (145): 0.274
Loss (146): 0.205
Loss (147): 0.243
Loss (148): 0.229
Loss (149): 0.225
Loss (150): 0.405
Loss (151): 0.265
Loss (152): 0.251
Loss (153): 0.262
Loss (154): 0.263
Loss (155): 0.236
Loss (156): 0.194
Loss (157): 0.194
Loss (158): 0.281
Loss (159): 0.165
Loss (160): 0.173
Loss (161): 0.215
Loss (162): 0.221
Loss (163): 0.215
Loss (164): 0.264
Loss (165): 0.278
Loss (166): 0.194
Loss (167): 0.250
Loss (168): 0.217
Loss (169): 0.191
Loss (170): 0.170
Loss (171): 0.217
Loss (172): 0.257
Loss (173): 0.239
Loss (174): 0.203
Loss (175): 0.222
Loss (176): 0.214
Loss (177): 0.248
Loss (178): 0.209
Loss (179): 0.166
Loss (180): 0.200
Loss (181): 0.202
Loss (182): 0.262
Loss (183): 0.233
Loss (184): 0.184
Loss (185): 0.277
Loss (186): 0.301
Loss (187): 0.197
Loss (188): 0.205
Loss (189): 0.317
Loss (190): 0.221
Loss (191): 0.264
Loss (192): 0.184
Loss (193): 0.288
Loss (194): 0.193
Loss (195): 0.229
Loss (196): 0.280
Loss (197): 0.194
Loss (198): 0.223
Loss (199): 0.287
Loss (200): 0.173
Loss (201): 0.189
Loss (202): 0.538
Loss (203): 0.439
Loss (204): 0.398
Loss (205): 0.340
Loss (206): 0.267
Loss (207): 0.251
Loss (208): 0.282
Loss (209): 0.372
Loss (210): 0.304
Loss (211): 0.280
Loss (212): 0.365
Loss (213): 0.271
Loss (214): 0.273
Loss (215): 0.232
Loss (216): 0.197
Loss (217): 0.196
Loss (218): 0.253
Loss (219): 0.237
Loss (220): 0.154
Loss (221): 0.271
Loss (222): 0.240
Loss (223): 0.344
Loss (224): 0.298
Loss (225): 0.201
Loss (226): 0.221
Loss (227): 0.276
Loss (228): 0.207
Loss (229): 0.259
Loss (230): 0.208
Loss (231): 0.240
Loss (232): 0.224
Loss (233): 0.175
Loss (234): 0.449
Loss (235): 0.330
Loss (236): 0.305
Loss (237): 0.154
Loss (238): 0.178
Loss (239): 0.233
Loss (240): 0.147
Loss (241): 0.305
Loss (242): 0.257
Loss (243): 0.207
Loss (244): 0.213
Loss (245): 0.164
Loss (246): 0.145
Loss (247): 0.188
Loss (248): 0.285
Loss (249): 0.222
Loss (250): 0.175
Loss (251): 0.208
Loss (252): 0.142
Loss (253): 0.190
Loss (254): 0.215
Loss (255): 0.240
Loss (256): 0.248
Loss (257): 0.177
Loss (258): 0.136
Loss (259): 0.175
Loss (260): 0.235
Loss (261): 0.264
Loss (262): 0.150
Loss (263): 0.235
Loss (264): 0.233
Loss (265): 0.171
Loss (266): 0.213
Loss (267): 0.218
Loss (268): 0.153
Loss (269): 0.116
Loss (270): 0.335
Loss (271): 0.162
Loss (272): 0.193
Loss (273): 0.190
Loss (274): 0.207
Loss (275): 0.218
Loss (276): 0.214
Loss (277): 0.181
Loss (278): 0.269
Loss (279): 0.208
Loss (280): 0.190
Loss (281): 0.235
Loss (282): 0.169
Loss (283): 0.224
Loss (284): 0.249
Loss (285): 0.196
Loss (286): 0.257
Loss (287): 0.239
Loss (288): 0.243
Loss (289): 0.138
Loss (290): 0.172
Loss (291): 0.223
Loss (292): 0.237
Loss (293): 0.176
Loss (294): 0.147
Loss (295): 0.155
Loss (296): 0.212
Loss (297): 0.094
Loss (298): 0.230
Loss (299): 0.128
Loss (300): 0.421
Loss (301): 0.155
Loss (302): 0.238
Loss (303): 0.146
Loss (304): 0.186
Loss (305): 0.206
Loss (306): 0.218
Loss (307): 0.189
Loss (308): 0.295
Loss (309): 0.130
Loss (310): 0.201
Loss (311): 0.181
Loss (312): 0.136
Loss (313): 0.148
Loss (314): 0.269
Loss (315): 0.214
Loss (316): 0.189
Loss (317): 0.239
Loss (318): 0.207
Loss (319): 0.174
Loss (320): 0.224
Loss (321): 0.199
Loss (322): 0.235
Loss (323): 0.168
Loss (324): 0.186
Loss (325): 0.126
Loss (326): 0.223
Loss (327): 0.167
Loss (328): 0.179
Loss (329): 0.172
Loss (330): 0.166
Loss (331): 0.177
Loss (332): 0.157
Loss (333): 0.236
Loss (334): 0.109
Loss (335): 0.233
Loss (336): 0.191
Loss (337): 0.161
Loss (338): 0.218
Loss (339): 0.183
Loss (340): 0.151
Loss (341): 0.159
Loss (342): 0.123
Loss (343): 0.091
Loss (344): 0.181
Loss (345): 0.178
Loss (346): 0.180
Loss (347): 0.075
Loss (348): 0.259
Loss (349): 0.178
Loss (350): 0.198
Loss (351): 0.318
Loss (352): 0.130
Loss (353): 0.188
Loss (354): 0.124
Loss (355): 0.153
Loss (356): 0.176
Loss (357): 0.110
Loss (358): 0.148
Loss (359): 0.186
Loss (360): 0.141
Loss (361): 0.102
Loss (362): 0.149
Loss (363): 0.239
Loss (364): 0.167
Loss (365): 0.225
Loss (366): 0.155
Loss (367): 0.149
Loss (368): 0.240
Loss (369): 0.159
Loss (370): 0.167
Loss (371): 0.196
Loss (372): 0.139
Loss (373): 0.226
Loss (374): 0.190
Loss (375): 0.124
Loss (376): 0.127
Loss (377): 0.288
Loss (378): 0.198
Loss (379): 0.148
Loss (380): 0.271
Loss (381): 0.176
Loss (382): 0.186
Loss (383): 0.201
Loss (384): 0.166
Loss (385): 0.082
Loss (386): 0.133
Loss (387): 0.127
Loss (388): 0.164
Loss (389): 0.188
Loss (390): 0.259
Loss (391): 0.112
Loss (392): 0.128
Loss (393): 0.119
Loss (394): 0.147
Loss (395): 0.192
Loss (396): 0.301
Loss (397): 0.143
Loss (398): 0.125
Loss (399): 0.158
Loss (400): 0.158
# Visualize how the training loss evolved over the 400 iterations.
plt.plot(losses)
plt.title("Loss Curve")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
../_images/59f68075ba145813c56182136300c4cf3a4d58ca079cd2987655078f085f984c.png
import mwatershed as mws

# Let's predict on some validation data:
val_raw, val_gt = datasplit.validate[0].raw, datasplit.validate[0].gt
# fetch a xy slice from the center of our validation volume
# We snap to grid to a multiple of the max downsampling factor of
# the unet (1, 8, 8) to ensure downsampling is always possible
roi = val_raw.roi
# Coordinate masks: z_coord keeps only the z component, xy_coord only y/x.
z_coord = Coordinate(1, 0, 0)
xy_coord = Coordinate(0, 1, 1)
# center of the volume in z, volume origin in xy
center_offset = roi.center * z_coord + roi.offset * xy_coord
# one voxel thick in z, half the volume extent in xy
center_size = val_raw.voxel_size * z_coord + (roi.shape * xy_coord) // 2
center_slice = Roi(center_offset, center_size)
center_slice = center_slice.snap_to_grid(val_raw.voxel_size * Coordinate(1, 8, 8))
# extra raw context needed on each side so the valid-convolution UNet's
# output covers the full center_slice
context = (input_shape - output_shape) // 2 * val_raw.voxel_size

# Read the raw data
raw_input = val_raw.to_ndarray(center_slice.grow(context, context))

# Predict on the validation data
with torch.no_grad():
    device = torch.device("cpu")
    module = module.to(device)
    # add batch and channel dimensions before the forward pass
    pred = (
        module(torch.from_numpy(raw_input).to(device).unsqueeze(0).unsqueeze(0))
        .cpu()
        .detach()
        .numpy()
    )
# Plot the results
fig, axes = plt.subplots(1, 4, figsize=(24, 8))
# output is smaller than input by `padding` voxels on each side
padding = (input_shape - output_shape) // 2

# select the long range affinity channels for visualization
prediction = pred[0, [0, 5, 6], 0]

# Run mutex watershed on the affinity predictions.
# We subtract 0.5 to move affs from range (0, 1) to (-0.5, 0.5).
# This is because mutex only splits objects on negative edges.
pred_labels = (
    mws.agglom(pred[0].astype(np.float64) - 0.5, offsets=affs_config.neighborhood)[0]
    % 256
)

# read the ground truth labels
gt_labels = val_gt.to_ndarray(center_slice)[0] % 256

# Pad predictions and labels back to the raw input's xy extent so all
# panels align. NaN for the float affinity image, 0 (background) for labels.
prediction = np.pad(
    prediction,
    ((0,), (padding[1],), (padding[2],)),
    mode="constant",
    constant_values=np.nan,
)
pred_labels = np.pad(
    pred_labels,
    ((padding[1],), (padding[2],)),
    mode="constant",
    constant_values=0,
)
gt_labels = np.pad(
    gt_labels,
    ((padding[1],), (padding[2],)),
    mode="constant",
    constant_values=0,
)

# Plot the results
# raw_input[2] is the central z slice of the 5-slice input
im_raw = axes[0].imshow(raw_input[2])
im2 = axes[1].imshow(gt_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none")
im4 = axes[2].imshow(prediction.transpose(1, 2, 0), interpolation="none")
im5 = axes[3].imshow(
    pred_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
)
axes[0].set_title("Val Raw")
axes[1].set_title("Val Labels")
axes[2].set_title("Pred Affinities")
axes[3].set_title("Pred Labels")
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-5.069539..12.011909].
../_images/1d5da79a696f5b20ae3954b9714e6561496437d2bf59aef75cff8a63ec026ac9.png